{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dd72df66",
   "metadata": {},
   "source": [
    "# Finetuning Quantized Llama models with _Adapters_\n",
    "\n",
    "In this notebook, we show how to efficiently fine-tune a quantized **Llama 2** or **Llama 3** model using [**QLoRA** (Dettmers et al., 2023)](https://arxiv.org/abs/2305.14314) and the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library.\n",
    "\n",
    "For this example, we finetune Llama-2 7B/ Llama-3 8B on supervised instruction tuning data collected by the [Open Assistant project](https://github.com/LAION-AI/Open-Assistant) for training chatbots. This is similar to the setup used to train the Guanaco models in the QLoRA paper.\n",
    "You can simply replace this with any of your own domain-specific data!\n",
    "\n",
    "Additionally, you can quickly adapt this notebook to use other **adapter methods such as bottleneck adapters or prefix tuning.**\n",
    "\n",
    "Pre-trained checkpoints based on this notebook can be found on HuggingFace Hub:\n",
    "- for Llama-2 7B: [AdapterHub/llama2-7b-qlora-openassistant](https://huggingface.co/AdapterHub/llama2-7b-qlora-openassistant)\n",
    "- for Llama-2 13B: [AdapterHub/llama2-13b-qlora-openassistant](https://huggingface.co/AdapterHub/llama2-13b-qlora-openassistant)\n",
    "- for Llama-2 7B with sequential bottleneck adapter: [AdapterHub/llama2-7b-qadapter-seq-openassistant](https://huggingface.co/AdapterHub/llama2-7b-qadapter-seq-openassistant)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2aa993e5",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "Besides `adapters`, we require `bitsandbytes` for quantization and `accelerate` for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7e7bec93-c233-4493-b5d7-23d06f02d218",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -qq -U adapters accelerate bitsandbytes datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3ad77898",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e64aabb0-25ef-455f-9494-b6f4ca3ecfc9",
   "metadata": {},
   "source": [
    "## Load Open Assistant dataset\n",
    "\n",
    "We use the [`timdettmers/openassistant-guanaco`](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) dataset by the QLoRA, which contains a small subset of conversations from the full Open Assistant database and was also used to finetune the Guanaco models in the QLoRA paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fe883fe8-868c-4cc8-92e5-ed9889143ae7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/adapters/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Repo card metadata block was not found. Setting CardData to empty.\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"timdettmers/openassistant-guanaco\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bb2c4bf",
   "metadata": {},
   "source": [
    "Our training dataset has roughly 10k training samples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7aab5773-415c-4c90-904a-f5c0a755abfa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text'],\n",
       "        num_rows: 9846\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text'],\n",
       "        num_rows: 518\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f2b666c7-108b-4a64-aaec-64dfb1b10078",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "### Human: Can you write a short introduction about the relevance of the term \"monopsony\" in economics? Please use examples related to potential monopsonies in the labour market and cite relevant research.### Assistant: \"Monopsony\" refers to a market structure where there is only one buyer for a particular good or service. In economics, this term is particularly relevant in the labor market, where a monopsony employer has significant power over the wages and working conditions of their employees. The presence of a monopsony can result in lower wages and reduced employment opportunities for workers, as the employer has little incentive to increase wages or provide better working conditions.\n",
      "\n",
      "Recent research has identified potential monopsonies in industries such as retail and fast food, where a few large companies control a significant portion of the market (Bivens & Mishel, 2013). In these industries, workers often face low wages, limited benefits, and reduced bargaining power, leading to a situation where they are dependent on the employer for their livelihood. This dependence can result in further suppression of wages and a decline in working conditions.\n",
      "\n",
      "Overall, the concept of monopsony is essential to understanding the dynamics of labor markets and the impact of market power on workers. Further research is needed to understand the extent and impact of monopsonies on the economy and to develop policies to address this issue.\n",
      "\n",
      "References:\n",
      "Bivens, J., & Mishel, L. (2013). The Pay of Corporate Executives and Financial Professionals as Evidence of Rents in Top 1 Percent Incomes. Journal of Economic Perspectives, 27(3), 57-78.### Human: Now explain it to a dog\n"
     ]
    }
   ],
   "source": [
    "print(dataset[\"train\"][0][\"text\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c851ab41-4a7b-4df1-8140-48623daeae99",
   "metadata": {},
   "source": [
    "## Load and prepare model and tokenizer\n",
    "\n",
    "We download the the official Llama-2 7B/ Llama-3 8B checkpoint from the HuggingFace Hub (**Note:** You must request access to this model on the HuggingFace website and use an API token to download it.).\n",
    "\n",
    "Via the `BitsAndBytesConfig`, we specify that the model should be loaded in 4bit quantization and with double quantization for even better memory efficiency. See [their documentation](https://huggingface.co/docs/bitsandbytes/main/en/index) for more on this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a8ef7f26-f87b-4c54-924f-c9661bc1bf2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading shards: 100%|██████████| 4/4 [01:28<00:00, 22.22s/it]\n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.38s/it]\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig\n",
    "\n",
    "# modelpath=\"meta-llama/Llama-2-7b-hf\"\n",
    "modelpath=\"meta-llama/Meta-Llama-3-8B\"\n",
    "\n",
    "# Load 4-bit quantized model\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    modelpath,    \n",
    "    device_map=\"auto\",\n",
    "    quantization_config=BitsAndBytesConfig(\n",
    "        load_in_4bit=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ),\n",
    "    torch_dtype=torch.bfloat16,\n",
    ")\n",
    "model.config.use_cache = False\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(modelpath)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d6f69bb",
   "metadata": {},
   "source": [
    "We initialize the adapter functionality in the loaded model via `adapters.init()` and add a new LoRA adapter (named `\"assistant_adapter\"`) via `add_adapter()`.\n",
    "\n",
    "In the call to `LoRAConfig()`, you can configure how and where LoRA layers are added to the model. Here, we want to add LoRA layers to all linear projections of the self-attention modules (`attn_matrices=[\"q\", \"k\", \"v\"]`) as well as intermediate and outputa linear layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6897387e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Name                     Architecture         #Param      %Param  Active   Train\n",
      "--------------------------------------------------------------------------------\n",
      "assistant_adapter        lora            113,246,208       2.820       1       1\n",
      "--------------------------------------------------------------------------------\n",
      "Full model                              4,015,263,744     100.000               0\n",
      "================================================================================\n"
     ]
    }
   ],
   "source": [
    "import adapters\n",
    "from adapters import LoRAConfig\n",
    "\n",
    "adapters.init(model)\n",
    "\n",
    "config = LoRAConfig(\n",
    "    selfattn_lora=True, intermediate_lora=True, output_lora=True,\n",
    "    attn_matrices=[\"q\", \"k\", \"v\"],\n",
    "    alpha=16, r=64, dropout=0.1\n",
    ")\n",
    "model.add_adapter(\"assistant_adapter\", config=config)\n",
    "model.train_adapter(\"assistant_adapter\")\n",
    "\n",
    "print(model.adapter_summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9eec8d5",
   "metadata": {},
   "source": [
    "To correctly train bottleneck adapters or prefix tuning, uncomment the following lines to move the adapter weights to GPU explicitly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca0c26e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.adapter_to(\"assistant_adapter\", device=\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19b25d99",
   "metadata": {},
   "source": [
    "Some final preparations for 4bit training: we cast a few parameters to float32 for stability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "75ff7deb",
   "metadata": {},
   "outputs": [],
   "source": [
    "for param in model.parameters():\n",
    "    if param.ndim == 1:\n",
    "        # cast the small parameters (e.g. layernorm) to fp32 for stability\n",
    "        param.data = param.data.to(torch.float32)\n",
    "\n",
    "# Enable gradient checkpointing to reduce required memory if needed\n",
    "# model.gradient_checkpointing_enable()\n",
    "# model.enable_input_require_grads()\n",
    "\n",
    "class CastOutputToFloat(torch.nn.Sequential):\n",
    "    def forward(self, x): return super().forward(x).to(torch.float32)\n",
    "model.lm_head = CastOutputToFloat(model.lm_head)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e7211c68-07a3-4d24-a7af-a3691063a758",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(128256, 4096)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x LlamaDecoderLayerWithAdapters(\n",
       "        (self_attn): LlamaSdpaAttentionWithAdapters(\n",
       "          (q_proj): LoRALinear4bit(\n",
       "            in_features=4096, out_features=4096, bias=False\n",
       "            (loras): ModuleDict(\n",
       "              (assistant_adapter): LoRA(\n",
       "                (lora_dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "          )\n",
       "          (k_proj): LoRALinear4bit(\n",
       "            in_features=4096, out_features=1024, bias=False\n",
       "            (loras): ModuleDict(\n",
       "              (assistant_adapter): LoRA(\n",
       "                (lora_dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "          )\n",
       "          (v_proj): LoRALinear4bit(\n",
       "            in_features=4096, out_features=1024, bias=False\n",
       "            (loras): ModuleDict(\n",
       "              (assistant_adapter): LoRA(\n",
       "                (lora_dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "          )\n",
       "          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
       "          (rotary_emb): LlamaRotaryEmbedding()\n",
       "          (prefix_tuning): PrefixTuningLayer(\n",
       "            (prefix_gates): ModuleDict()\n",
       "            (pool): PrefixTuningPool(\n",
       "              (prefix_tunings): ModuleDict()\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)\n",
       "          (up_proj): LoRALinear4bit(\n",
       "            in_features=4096, out_features=14336, bias=False\n",
       "            (loras): ModuleDict(\n",
       "              (assistant_adapter): LoRA(\n",
       "                (lora_dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "          )\n",
       "          (down_proj): LoRALinear4bit(\n",
       "            in_features=14336, out_features=4096, bias=False\n",
       "            (loras): ModuleDict(\n",
       "              (assistant_adapter): LoRA(\n",
       "                (lora_dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "          )\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm()\n",
       "        (post_attention_layernorm): LlamaRMSNorm()\n",
       "        (attention_adapters): BottleneckLayer(\n",
       "          (adapters): ModuleDict()\n",
       "          (adapter_fusion_layer): ModuleDict()\n",
       "        )\n",
       "        (output_adapters): BottleneckLayer(\n",
       "          (adapters): ModuleDict()\n",
       "          (adapter_fusion_layer): ModuleDict()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm()\n",
       "    (shared_parameters): ModuleDict()\n",
       "    (invertible_adapters): ModuleDict()\n",
       "    (prefix_tuning): PrefixTuningPool(\n",
       "      (prefix_tunings): ModuleDict()\n",
       "    )\n",
       "  )\n",
       "  (lm_head): CastOutputToFloat(\n",
       "    (0): Linear(in_features=4096, out_features=128256, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
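  {
   "cell_type": "markdown",
   "id": "5c0e1a71",
   "metadata": {},
   "source": [
    "As a sanity check, the adapter size reported by `adapter_summary()` can be reproduced by hand. The following is a minimal sketch, assuming that each LoRA module adds two matrices without bias, `A` of shape `(r, in_features)` and `B` of shape `(out_features, r)`, and using the Llama-3 8B layer shapes read off the printout above. The variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c0e1a72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough LoRA parameter count (a sketch; shapes are copied from the model printout above)\n",
    "r = 64           # LoRA rank used in the LoRAConfig\n",
    "num_layers = 32  # number of decoder layers in Llama-3 8B\n",
    "\n",
    "# (in_features, out_features) of every linear layer that carries a LoRA module\n",
    "lora_layers = {\n",
    "    \"q_proj\": (4096, 4096),\n",
    "    \"k_proj\": (4096, 1024),\n",
    "    \"v_proj\": (4096, 1024),\n",
    "    \"up_proj\": (4096, 14336),\n",
    "    \"down_proj\": (14336, 4096),\n",
    "}\n",
    "\n",
    "# A has r * in_features entries, B has out_features * r entries\n",
    "params_per_layer = sum(r * (n_in + n_out) for n_in, n_out in lora_layers.values())\n",
    "total_lora_params = num_layers * params_per_layer\n",
    "\n",
    "# Should agree with the #Param column of the adapter summary\n",
    "print(f\"{total_lora_params:,}\")"
   ]
  },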
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "24289126",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.bfloat16 1050673152 0.22576446079143245\n",
      "torch.uint8 3489660928 0.749844436640606\n",
      "torch.float32 113512448 0.024391102567961606\n"
     ]
    }
   ],
   "source": [
    "# Verifying the datatypes.\n",
    "dtypes = {}\n",
    "for _, p in model.named_parameters():\n",
    "    dtype = p.dtype\n",
    "    if dtype not in dtypes:\n",
    "        dtypes[dtype] = 0\n",
    "    dtypes[dtype] += p.numel()\n",
    "total = 0\n",
    "for k, v in dtypes.items():\n",
    "    total += v\n",
    "for k, v in dtypes.items():\n",
    "    print(k, v, v / total)"
   ]
  },
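  {
   "cell_type": "markdown",
   "id": "5c0e1a73",
   "metadata": {},
   "source": [
    "The element counts above give a back-of-the-envelope estimate of the memory taken up by the parameters alone. This sketch assumes the counts printed for Llama-3 8B in the previous cell and that each `uint8` element occupies one byte (it packs two 4-bit weights). It deliberately ignores quantization constants, activations, gradients, and optimizer state, so actual GPU usage during training will be higher."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c0e1a74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameter memory estimate from the dtype counts printed above (illustrative only)\n",
    "bytes_per_element = {\"bfloat16\": 2, \"uint8\": 1, \"float32\": 4}\n",
    "num_elements = {\n",
    "    \"bfloat16\": 1_050_673_152,\n",
    "    \"uint8\": 3_489_660_928,\n",
    "    \"float32\": 113_512_448,\n",
    "}\n",
    "\n",
    "total_bytes = sum(num_elements[name] * bytes_per_element[name] for name in num_elements)\n",
    "total_gib = total_bytes / 2**30\n",
    "print(f\"{total_gib:.2f} GiB\")"
   ]
  },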
  {
   "cell_type": "markdown",
   "id": "9913bc18-8a26-4a1b-8abc-3bc8e672f191",
   "metadata": {},
   "source": [
    "## Prepare data for training\n",
    "\n",
    "The dataset is tokenized and truncated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "92a941dc-9e9e-4bbd-9c2a-a54f2fd72071",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map (num_proc=24): 100%|██████████| 9846/9846 [00:01<00:00, 9098.38 examples/s] \n",
      "Map (num_proc=24): 100%|██████████| 518/518 [00:00<00:00, 1953.76 examples/s]\n"
     ]
    }
   ],
   "source": [
    "import os \n",
    "\n",
    "def tokenize(element):\n",
    "    return tokenizer(\n",
    "        element[\"text\"],\n",
    "        truncation=True,\n",
    "        max_length=512, # can set to longer values such as 2048\n",
    "        add_special_tokens=False,\n",
    "    )\n",
    "\n",
    "dataset_tokenized = dataset.map(\n",
    "    tokenize, \n",
    "    batched=True, \n",
    "    num_proc=os.cpu_count(),    # multithreaded\n",
    "    remove_columns=[\"text\"]     # don't need this anymore, we have tokens from here on\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c427e534-2214-4bc6-8c73-4a81c4984db5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input_ids', 'attention_mask'],\n",
       "        num_rows: 9846\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input_ids', 'attention_mask'],\n",
       "        num_rows: 518\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset_tokenized"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27287f0e-56f0-458f-8c2e-9124a9739d48",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "We specify training hyperparameters and train the model using the `AdapterTrainer` class.\n",
    "\n",
    "The hyperparameters here are similar to those chosen [in the official QLoRA repo](https://github.com/artidoro/qlora/blob/main/scripts/finetune_llama2_guanaco_7b.sh), but feel free to configure as you wish!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e36b9ad9-d0c0-4dc7-a4c8-03765891dac6",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    output_dir=\"output/llama_qlora\",\n",
    "    per_device_train_batch_size=1,\n",
    "    per_device_eval_batch_size=1,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    logging_steps=10,\n",
    "    save_steps=500,\n",
    "    eval_steps=187,\n",
    "    save_total_limit=3,\n",
    "    gradient_accumulation_steps=16,\n",
    "    max_steps=1875,\n",
    "    lr_scheduler_type=\"constant\",\n",
    "    optim=\"paged_adamw_32bit\",\n",
    "    learning_rate=0.0002,\n",
    "    group_by_length=True,\n",
    "    bf16=True,\n",
    "    warmup_ratio=0.03,\n",
    "    max_grad_norm=0.3,\n",
    ")"
   ]
  },
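  {
   "cell_type": "markdown",
   "id": "5c0e1a75",
   "metadata": {},
   "source": [
    "To get a feeling for what these settings mean, the following sketch works out the effective batch size and the approximate number of passes over the training data. It assumes a single GPU (as selected via `CUDA_VISIBLE_DEVICES` at the top of the notebook) and the 9,846 training samples shown earlier. The variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c0e1a76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values copied from the TrainingArguments above\n",
    "per_device_train_batch_size = 1\n",
    "gradient_accumulation_steps = 16\n",
    "max_steps = 1875\n",
    "\n",
    "num_devices = 1           # assumption: training on a single GPU\n",
    "num_train_samples = 9846  # size of the train split\n",
    "\n",
    "# One optimizer step consumes this many samples\n",
    "effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices\n",
    "samples_seen = effective_batch_size * max_steps\n",
    "approx_epochs = samples_seen / num_train_samples\n",
    "\n",
    "print(effective_batch_size, samples_seen, round(approx_epochs, 2))"
   ]
  },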
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e5d8b96-bcfe-490d-8462-dfb44d575432",
   "metadata": {},
   "outputs": [],
   "source": [
    "from adapters import AdapterTrainer\n",
    "from transformers import DataCollatorForLanguageModeling\n",
    "\n",
    "trainer = AdapterTrainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
    "    train_dataset=dataset_tokenized[\"train\"],\n",
    "    eval_dataset=dataset_tokenized[\"test\"],\n",
    "    args=args,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0efbf7c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe958104",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "Finally, we can prompt the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "663e9b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ignore warnings\n",
    "from transformers import logging\n",
    "logging.set_verbosity(logging.CRITICAL)\n",
    "\n",
    "def prompt_model(model, text: str):\n",
    "    batch = tokenizer(f\"### Human: {text}\\n### Assistant:\", return_tensors=\"pt\")\n",
    "    batch = batch.to(model.device)\n",
    "    \n",
    "    model.eval()\n",
    "    with torch.inference_mode(), torch.cuda.amp.autocast():\n",
    "        output_tokens = model.generate(**batch, max_new_tokens=50)\n",
    "\n",
    "    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fe0e9c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(prompt_model(model, \"Explain Calculus to a primary school student\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcb96af2",
   "metadata": {},
   "source": [
    "## Merge LoRA weights\n",
    "\n",
    "For lower inference latency, the LoRA weights can be merged with the base model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f5176ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.merge_adapter(\"assistant_adapter\")"
   ]
  },
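  {
   "cell_type": "markdown",
   "id": "5c0e1a77",
   "metadata": {},
   "source": [
    "Conceptually, merging folds the low-rank update into the base weight once, so that inference needs only a single matrix product per layer: `W_merged = W + (alpha / r) * (B @ A)`. The sketch below illustrates this standard LoRA formulation on a tiny dense layer in pure Python. All numbers and the `matmul` helper are made up for the example, and it does not reflect how `merge_adapter()` handles 4-bit quantized weights internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c0e1a78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative LoRA merge on a tiny dense layer (no quantization involved)\n",
    "def matmul(a, b):\n",
    "    # Plain nested-list matrix product\n",
    "    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]\n",
    "\n",
    "alpha, rank = 16, 2\n",
    "scaling = alpha / rank\n",
    "\n",
    "W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # base weight (out=2, in=3)\n",
    "A = [[0.5, 0.0, 0.0], [0.0, 0.25, 0.0]]  # LoRA down-projection (rank=2, in=3)\n",
    "B = [[1.0, 0.0], [0.0, 2.0]]             # LoRA up-projection (out=2, rank=2)\n",
    "x = [[1.0], [2.0], [3.0]]                # input column vector\n",
    "\n",
    "# Unmerged: base output plus the scaled low-rank update\n",
    "base_out = matmul(W, x)\n",
    "lora_out = matmul(B, matmul(A, x))\n",
    "y_unmerged = [[b[0] + scaling * l[0]] for b, l in zip(base_out, lora_out)]\n",
    "\n",
    "# Merged: fold the update into the weight once, then do a single matmul\n",
    "BA = matmul(B, A)\n",
    "W_merged = [[w + scaling * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, BA)]\n",
    "y_merged = matmul(W_merged, x)\n",
    "\n",
    "# Both paths give the same output\n",
    "print(y_unmerged, y_merged)"
   ]
  },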
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "403f84a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(prompt_model(model, \"Explain NLP in simple terms\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
